import os
import math
import hashlib
import pickle

import torch
import numpy as np
import h5py

from ..qc.numpy import exact_spectrum
from ..qc.qiskit import QiskitBackend
from ..utils import circuit_param_size
from .bo import LeastSquaresWave, make_shift


# Registry of available backends: maps the name accepted by the ``backend``
# argument of ``DataSampler`` to the class that is instantiated for it.
BACKENDS = {
    'qiskit': QiskitBackend,
}


def arrhash(*args):
    """Return a SHA-256 hex digest identifying a sequence of arguments.

    NumPy arrays and torch tensors contribute their raw element bytes in
    C (row-major) order; every other argument contributes its pickle
    serialization. Arguments are fed to the hash in the order given.

    Args:
        *args: Arrays, tensors, or arbitrary picklable objects.

    Returns:
        The hexadecimal SHA-256 digest as a string.

    NOTE: shape and dtype are not part of the digest, so two arrays that
    share the same bytes but differ in shape or dtype hash identically.
    This is kept as-is so that previously computed digests remain valid.
    """
    hasher = hashlib.sha256()
    for arg in args:
        if isinstance(arg, torch.Tensor):
            # detach/cpu so tensors that require grad or live on another
            # device can still be converted to NumPy.
            arg = arg.detach().cpu().numpy()
        if isinstance(arg, np.ndarray):
            # tobytes() always emits C-order bytes, so non-contiguous inputs
            # (e.g. transposed arrays) hash cleanly instead of failing in
            # hashlib, which only accepts C-contiguous buffers.
            hasher.update(arg.tobytes())
        else:
            hasher.update(pickle.dumps(arg))
    return hasher.hexdigest()


def arrcache(fname, func, identifiers, keys='value'):
    """Compute ``func()`` once and cache its result(s) in an HDF5 file.

    Results are stored under a group named ``arrhash(*identifiers)``, with
    one dataset per entry of ``keys``. Each dataset carries a ``type``
    attribute (``'torch'`` or ``'numpy'``) so tensors are restored as
    tensors on a later hit.

    Args:
        fname: Path of the HDF5 cache file, or ``None`` to bypass caching
            entirely and return ``func()`` directly.
        func: Zero-argument callable producing the value (when ``keys`` is a
            string) or a sequence of values matching ``keys``.
        identifiers: Iterable of objects hashed with ``arrhash`` to form the
            cache key.
        keys: A dataset name, or a tuple of dataset names, for the result(s).

    Returns:
        On a cache hit, the stored value(s): a ``torch.Tensor`` for datasets
        tagged ``'torch'``, otherwise whatever h5py returns for the dataset.
        On a miss, the value(s) returned by ``func()``. A single value is
        returned when ``keys`` is a string.

    Raises:
        RuntimeError: If writing the cache file raises ``OSError``.
    """
    if fname is None:
        return func()
    single = isinstance(keys, str)
    if single:
        keys = (keys,)

    identifier = arrhash(*identifiers)
    results = None

    if os.path.exists(fname):
        with h5py.File(fname, 'r') as fd:
            # Only count the entry as a hit when every requested dataset is
            # present. A partially written group (e.g. from an interrupted
            # run) is recomputed instead of raising KeyError.
            if identifier in fd and all(f'{identifier}/{key}' in fd for key in keys):
                results = tuple(
                    torch.from_numpy(np.array(dset[()])) if dset.attrs.get('type', 'numpy') == 'torch' else dset[()]
                    for dset in (fd[f'{identifier}/{key}'] for key in keys)
                )

    if results is None:
        results = func()
        if single:
            results = (results,)
        try:
            with h5py.File(fname, 'a') as fd:
                for key, result in zip(keys, results):
                    path = f'{identifier}/{key}'
                    if path in fd:
                        # Leftover from an incomplete entry: replace it, since
                        # h5py refuses to assign over an existing name.
                        del fd[path]
                    fd[path] = result
                    fd[path].attrs['type'] = 'torch' if isinstance(result, torch.Tensor) else 'numpy'
        except OSError as error:
            raise RuntimeError(f'Unable to cache key \'{identifier}\'.') from error

    if single:
        results, = results
    return results


class DataSampler:
    """Draw circuit angles and evaluate their energies on a backend.

    A random angle template of shape ``(1, n_circuit_params, n_qbits)`` is
    drawn once at construction. The sampling methods repeat this template and
    overwrite part of it. Expensive evaluations can be cached in an HDF5 file
    through ``arrcache`` when ``cache_fname`` is set.
    """

    def __init__(
        self,
        n_qbits,
        n_layers,
        j=(-1., 0., 0.),
        h=(0., 0., -1.),
        n_readout=1024,
        n_free_angles=None,
        rng=None,
        sector=1,
        noise_level=0.,
        prob_1to0=0.,
        prob_0to1=0.,
        pbc=False,
        circuit='esu2',
        backend='qiskit',
        cache_fname=None,
    ):
        """Set up the angle template and instantiate the backend.

        Args:
            n_qbits: Number of qubits, forwarded to the backend.
            n_layers: Number of circuit layers, forwarded to the backend and
                to ``circuit_param_size``.
            j: Forwarded to the backend as ``j``.
            h: Forwarded to the backend as ``h``.
            n_readout: Default readout count used by ``sample`` and
                ``sample_linspace``.
            n_free_angles: Number of leading (flattened) angles randomised by
                ``sample``; clipped to the total number of angles. ``None``
                frees all of them.
            rng: NumPy random generator; a fresh ``default_rng()`` if ``None``.
            sector: Forwarded to the backend as ``mom_sector``.
            noise_level: Accepted but not used by this constructor.
            prob_1to0: Accepted but not used by this constructor.
            prob_0to1: Accepted but not used by this constructor.
            pbc: Forwarded to the backend as ``pbc``.
            circuit: Circuit name, forwarded to the backend and to
                ``circuit_param_size``.
            backend: Key into ``BACKENDS`` selecting the backend class.
            cache_fname: HDF5 cache path, or ``None`` to disable caching.
        """
        self._n_circuit_params = circuit_param_size(circuit, n_layers)
        if n_free_angles is None:
            n_free_angles = self._n_circuit_params * n_qbits
        self.n_free_angles = min(self._n_circuit_params * n_qbits, n_free_angles)
        if rng is None:
            rng = np.random.default_rng()
        self.x_template = rng.uniform(0, 2 * math.pi, (1, self._n_circuit_params, n_qbits))
        self.rng = rng
        self.n_readout = n_readout

        self.backend = BACKENDS[backend](
            n_qbits=n_qbits,
            n_layers=n_layers,
            j=j,
            h=h,
            circuit=circuit,
            mom_sector=sector,
            pbc=pbc,
            rng=self.rng,
        )

        self.cache_fname = cache_fname
        # Lazily filled by exact_diag() and the shot_var property.
        self._exact_diag = None
        self._shot_var = None

    @staticmethod
    def _scalar(value):
        """Unwrap a one-element tensor/array/NumPy scalar; pass plain numbers through."""
        return value.item() if hasattr(value, 'item') else value

    @property
    def n_angles(self):
        """Total number of circuit angles (parameters per qubit times qubits)."""
        return self._n_circuit_params * self.backend.n_qbits

    def true_energy(self, angles, n_readout, var=None):
        """Evaluate the backend at ``angles``.

        Args:
            angles: Angles passed to the backend. When this is a NumPy array,
                the backend outputs are converted with ``torch.from_numpy``.
            n_readout: Readout count passed to the backend. May be a plain
                number or a one-element tensor/array.
            var: ``None``/``'none'`` to return only the mean, ``'measure'`` to
                also return the variance reported by the backend, or
                ``'estimate'`` to also return ``shot_var / n_readout``
                broadcast to the shape of the mean.

        Returns:
            The mean, or a ``(mean, variance)`` pair depending on ``var``.

        Raises:
            RuntimeError: If ``var`` is not one of the supported modes.
        """
        mean, variance = self.backend(angles, n_readout)
        if isinstance(angles, np.ndarray):
            mean = torch.from_numpy(mean)
            variance = torch.from_numpy(variance)

        if var == 'measure':
            return mean, variance
        if var == 'estimate':
            # n_readout is commonly a plain int (e.g. self.n_readout), which
            # has no .item(); unwrap both operands uniformly.
            per_shot = self._scalar(self.shot_var) / self._scalar(n_readout)
            return mean, torch.full_like(mean, per_shot)
        if var not in ('none', None):
            raise RuntimeError(f'No such variance mode: \'{var}\'')
        return mean

    def true_grad(self, angles, n_readout, var=None):
        """Evaluate the backend's parameter-shift gradient at ``angles``.

        Args:
            angles: Angles passed to the backend.
            n_readout: Readout count passed to the backend.
            var: Must be ``None`` or ``'none'``; variance modes are not
                supported for gradients.

        Returns:
            The gradient; converted with ``torch.from_numpy`` when ``angles``
            is not a NumPy array.

        Raises:
            RuntimeError: If a variance mode is requested.
        """
        grad = self.backend.parameter_shift_gradient(angles, n_readout)
        # NOTE(review): true_energy converts when angles IS an ndarray, while
        # this converts when it is NOT -- confirm which one is intended.
        if not isinstance(angles, np.ndarray):
            grad = torch.from_numpy(grad)

        if var not in ('none', None):
            raise RuntimeError(f'Grad does not support variance mode: \'{var}\'')
        return grad

    def sample(self, n_samples=1000, known=True):
        """Draw ``n_samples`` angle sets with the free angles randomised.

        The template is repeated ``n_samples`` times and the first
        ``n_free_angles`` entries of each flattened sample are replaced by
        uniform draws from ``[0, 2*pi)``.

        Returns:
            The angles as a tensor, or ``(angles, energies)`` when ``known``
            is true, with energies evaluated at ``self.n_readout``.
        """
        # expand and fill in template
        x_data = self.x_template.repeat(n_samples, axis=0)
        # The reshape of the freshly repeated array is written through to
        # fill x_data in place.
        x_data.reshape(
            x_data.shape[0], np.prod(x_data.shape[1:], dtype=int)
        )[:, :self.n_free_angles] = self.rng.uniform(0, 2 * math.pi, (n_samples, self.n_free_angles))

        retval = torch.from_numpy(x_data)
        # compute true values if known
        if known:
            retval = (retval, self.true_energy(x_data, self.n_readout))
        return retval

    def sample_linspace(self, n_samples=50, axes=(0, 0), known=True):
        """Sweep a single angle over ``[0, 2*pi]`` with all others fixed.

        Args:
            n_samples: Number of evenly spaced values.
            axes: Index (any sequence) of the swept angle within one sample.
            known: Whether to also evaluate the energies.

        Returns:
            The angles as a tensor, or ``(angles, energies)`` when ``known``
            is true, with energies evaluated at ``self.n_readout``.
        """
        # expand and fill in template with linspace
        x_data = self.x_template.repeat(n_samples, axis=0)
        # tuple() lets callers pass axes as a list as well as a tuple.
        x_data[(slice(None),) + tuple(axes)] = torch.linspace(0, 2 * math.pi, n_samples)

        retval = torch.from_numpy(x_data)
        # compute true values if known
        if known:
            retval = (retval, self.true_energy(x_data, self.n_readout))
        return retval

    def estimate_variance(self, n_samples, n_readout=None):
        """Average the measured variance over ``n_samples`` random angle sets.

        NOTE(review): ``n_readout=None`` is forwarded to the backend
        unchanged rather than replaced by ``self.n_readout`` -- confirm the
        backend accepts it.
        """
        x_reg_est = self.sample(n_samples, known=False)
        _, y_reg_est_var = self.true_energy(x_reg_est, n_readout=n_readout, var='measure')
        return y_reg_est_var.mean(0)

    @property
    def shot_var(self):
        """Shot variance, estimated on first access (16 samples, 8192 readouts)."""
        if self._shot_var is None:
            self._shot_var = self.estimate_shot_var(16, 8192)
        return self._shot_var

    def estimate_shot_var(self, n_samples, n_readout, force_compute=False):
        """Return ``estimate_variance(n_samples, n_readout) * n_readout``.

        The result is cached in ``cache_fname`` unless ``force_compute`` is
        true, in which case the cache is bypassed (neither read nor written).
        """
        return arrcache(
            self.cache_fname if not force_compute else None,
            lambda: self.estimate_variance(n_samples, n_readout) * n_readout,
            ('shot_var', n_samples, n_readout, *self._cache_identifiers),
            keys='shot_var'
        )

    def exact_diag(self):
        """Return ``(true_e0, true_e1, true_wf)`` from exact diagonalisation.

        The result is memoised on the instance and, when ``cache_fname`` is
        set, cached on disk keyed by the backend's ``n_qbits``, ``j``, ``h``
        and ``pbc``.
        """
        if self._exact_diag is None:
            args = (
                int(self.backend.n_qbits),
                tuple(self.backend.j),
                tuple(self.backend.h),
                bool(self.backend.pbc),
            )
            self._exact_diag = arrcache(
                self.cache_fname,
                lambda: exact_diag(*args),
                args,
                keys=('true_e0', 'true_e1', 'true_wf')
            )
        return self._exact_diag

    def exact_overlap(self, angles):
        """Return the backend's measured overlap of ``angles`` with the exact wavefunction."""
        _, _, exact_wf = self.exact_diag()

        return self.backend.measure_overlap(angles, exact_wf)

    def exact_energy(self, angles):
        """Evaluate the energy with ``n_readout=0``.

        NOTE(review): presumably the backend treats zero readouts as an exact
        (shot-noise-free) evaluation -- confirm against the backend.
        """
        return self.true_energy(angles, n_readout=0)

    @property
    def _cache_identifiers(self):
        """Tuple of plain-Python configuration values used in cache keys."""
        return (
            int(self.backend.n_qbits),
            int(self.backend.n_layers),
            tuple(self.backend.j),
            tuple(self.backend.h),
            int(self.n_readout),
            int(self.n_free_angles),
            int(self.backend.mom_sector),
            str(self.backend.name),
            str(self.backend.circuit),
            bool(self.backend.pbc),
        )

    def cached_sample(self, n_samples, key='train', force_compute=False):
        """Return ``sample(n_samples, known=True)``, cached under ``x_<key>``/``y_<key>``.

        ``force_compute`` bypasses the cache (neither read nor written).
        """
        return arrcache(
            self.cache_fname if not force_compute else None,
            lambda: self.sample(n_samples, known=True),
            (key, n_samples, *self._cache_identifiers),
            keys=(f'x_{key}', f'y_{key}')
        )

    def exact_line(self, pivot, lins, k_dim, force_compute=False):
        """Evaluate a line through ``pivot`` along dimension ``k_dim``.

        Exact energies are computed at the pivot and at the two points
        produced by ``make_shift`` for shifts of ``-tau/3`` and ``+tau/3``.
        They are then passed with ``lins`` to a ``LeastSquaresWave`` built
        from those shifts.

        Args:
            pivot: Tensor of angles the line passes through.
            lins: Passed through to the ``LeastSquaresWave`` call.
            k_dim: Dimension argument passed to ``make_shift``.
            force_compute: Bypass the cache for the exact energies.

        Returns:
            ``(y_cand_meas, y_cand_var)``: the ``LeastSquaresWave`` output and
            ``shot_var / n_readout``.
        """
        shifts = (-math.tau / 3, math.tau / 3)
        lsw = LeastSquaresWave(shifts)
        x_pairs = make_shift(pivot, torch.tensor(shifts), k_dim)

        x_train = torch.cat([pivot[None], x_pairs])
        y_train = arrcache(
            self.cache_fname if not force_compute else None,
            lambda: self.exact_energy(x_train),
            # The energies depend on the backend configuration, so it must be
            # part of the key; otherwise samplers with different settings but
            # the same pivot would share cached values.
            ('exact_line', pivot, *shifts, k_dim, *self._cache_identifiers),
            keys='y_train'
        )
        y_cand_meas = lsw(y_train[0], y_train[1:], lins)
        y_cand_var = self.shot_var / self.n_readout

        return y_cand_meas, y_cand_var


def exact_diag(n_qbits, j, h, pbc=True):
    """Return the two leading eigenvalues and the first eigenvector.

    Calls ``exact_spectrum(n_qbits, j, h, pbc=pbc)`` and picks entries 0 and
    1 of the eigenvalues together with column 0 of the eigenvectors.
    Presumably the spectrum is sorted ascending, making these the ground
    energy, first excited energy and ground state -- confirm against
    ``exact_spectrum``.

    Returns:
        Tuple ``(true_e0, true_e1, true_wf)``.
    """
    spectrum, states = exact_spectrum(n_qbits, j, h, pbc=pbc)
    lowest, second = spectrum[0], spectrum[1]
    return lowest, second, states[:, 0]
